{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yM2rpXs7-gdm"
   },
   "outputs": [],
   "source": [
    "# name of the directory containing weights\n",
    "name = \"pseudo_weights\"\n",
    "# name of the individual weights\n",
    "weights = [\"fold0\", \"fold1\", \"fold2\", \"fold3\", \"new_fold3\"]\n",
    "# names of the branches used to inference individual weights\n",
    "branches = [\"original\", \"original\", \"original\", \"original\", \"master\"]\n",
    "# indexes of labels whose confidences have to be *sqrt*ed\n",
    "sqrt_indexes = [4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K6wLjULbHbQd"
   },
   "source": [
    "# Get individual labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ybsUOSwNVIwE",
    "outputId": "72898a6c-6c9b-4870-b533-efe21a3e149c",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for index, weight in enumerate(weights):\n",
    "    print(index)\n",
    "    !cd GWC_YOLOv5 && git checkout {branches[index]} \n",
    "    !cd GWC_YOLOv5 && python detect.py --img-size 800 --name {weight} --weights ../{name}/{weight}.pt --source ../test --augment --nosave --save-txt --save-conf --conf-thres 0.50\n",
    "    !cp -r GWC_YOLOv5/runs/detect/{weight}/labels {weight}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e95DRDldZhId"
   },
   "source": [
    "# Weighted boxes fusion\n",
    "Use [weighted boxes fusion](https://github.com/ZFTurbo/Weighted-Boxes-Fusion) to ensemble individual labels for further training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b6n6s_anZPTl"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "files = [os.listdir(folder) for folder in weights]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o0ieGR0lZUMV"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def parse_box(box):\n",
    "    x, y, w, h = box\n",
    "    return max(0, x - w / 2), max(0, y - h / 2), min(1, x + w / 2), min(1, y + h / 2)\n",
    "\n",
    "\n",
    "def parse_file(file_name, current):\n",
    "    boxes = []\n",
    "    scores = []\n",
    "    labels = []\n",
    "    with open(file_name) as f:\n",
    "        for line in f:\n",
    "            cls, *xywh, conf = list(map(float, line.split()))\n",
    "            boxes.append(parse_box(xywh))\n",
    "            if current in sqrt_indexes:\n",
    "                scores.append(np.sqrt(conf))\n",
    "            else:\n",
    "                scores.append(conf)\n",
    "            labels.append(0)\n",
    "    return boxes, scores, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "O3ANa3loZWRW",
    "outputId": "3f620998-9ad7-4ded-f7d6-66b0976a13e9",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "from ensemble_boxes import weighted_boxes_fusion\n",
    "from tqdm import tqdm\n",
    "\n",
    "if os.path.exists(\"labels\"):\n",
    "    shutil.rmtree(\"labels\")\n",
    "os.mkdir(\"labels\")\n",
    "\n",
    "for img_name in tqdm(os.listdir(\"test\")):\n",
    "    txt_name = img_name[:-4] + \".txt\"\n",
    "    boxes_list = [[] for _ in range(len(folders))]\n",
    "    scores_list = [[] for _ in range(len(folders))]\n",
    "    labels_list = [[] for _ in range(len(folders))]\n",
    "    weights_list = [1 for _ in range(len(folders))]\n",
    "    for current, folder in enumerate(folders):\n",
    "        if txt_name in files[current]:\n",
    "            boxes, scores, labels = parse_file(os.path.join(folder, txt_name), current)\n",
    "            boxes_list[current] = boxes\n",
    "            scores_list[current] = scores\n",
    "            labels_list[current] = labels\n",
    "    final_boxes, final_scores, _ = weighted_boxes_fusion(\n",
    "        boxes_list, scores_list, labels_list, weights=weights_list\n",
    "    )\n",
    "\n",
    "    with open(\"labels/\" + txt_name, \"w\") as f:\n",
    "        for box in final_boxes:\n",
    "            x1, y1, x2, y2 = box\n",
    "            x_center = (x1 + x2) / 2\n",
    "            y_center = (y1 + y2) / 2\n",
    "            width = x2 - x1\n",
    "            height = y2 - y1\n",
    "            temp = \" \".join(\n",
    "                [\"0\", str(x_center), str(y_center), str(width), str(height)]\n",
    "            )\n",
    "            f.write(temp)\n",
    "            f.write(\"\\n\")\n",
    "\n",
    "assert len(os.listdir(\"test\")) == len(os.listdir(\"labels\"))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Copy of Get Labels.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
